{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LALE cross validation example\n",
    "This example, developed in collaboration with the [LALE team](https://github.com/IBM/lale), demonstrates how a LALE pipeline can be translated into a CodeFlare pipeline for cross validation.\n",
    "\n",
    "It assumes that LALE is available and installed in your local environment.\n",
    "\n",
    "Running this notebook shows that LALE cross validation is single-threaded and takes ~10 minutes on a laptop (depending on the configuration), whereas CodeFlare pipelines running on Ray reduce this time to around 75 seconds (with 8x parallelism) for the same 10-fold cross validation."
   ]
  },
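  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment the lines below to install lale and liac-arff before running this notebook.\n",
    "# `%pip` (rather than `!pip`) installs into the environment of the running kernel.\n",
    "\n",
    "# %pip install lale\n",
    "# %pip install 'liac-arff>=2.4.0'"
   ]
  },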
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from lale.datasets.openml import fetch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the OpenML dataset and unpack it into train and test (X, y) pairs.\n",
    "(X_train, y_train), (X_test, y_test) = fetch(\n",
    "    \"jungle_chess_2pcs_raw_endgame_complete\",\n",
    "    \"classification\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: build a simple LALE pipeline ending in a random forest, to be cross validated below.\n",
    "from lale.lib.lale import ConcatFeatures\n",
    "from lale.lib.sklearn import PCA, Nystroem, RandomForestClassifier, SelectKBest\n",
    "\n",
    "# Three feature transformers combined with LALE's `&` combinator, merged by ConcatFeatures,\n",
    "# then piped (`>>`) into a RandomForestClassifier with n_estimators=200.\n",
    "pipeline = (\n",
    "    (PCA() & Nystroem() & SelectKBest(k=3))\n",
    "    >> ConcatFeatures()\n",
    "    >> RandomForestClassifier(n_estimators=200)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7min 18s, sys: 8.91 s, total: 7min 27s\n",
      "Wall time: 6min 59s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.8161838161838162,\n",
       " 0.8105228105228105,\n",
       " 0.8148518148518149,\n",
       " 0.8218448218448219,\n",
       " 0.8208458208458208,\n",
       " 0.8111888111888111,\n",
       " 0.8105228105228105,\n",
       " 0.8181818181818182,\n",
       " 0.8011325782811459,\n",
       " 0.8121252498334444]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "from lale.helpers import cross_val_score\n",
    "cross_val_score(pipeline, X_train, y_train, cv=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-06-19 21:02:55,145\tINFO services.py:1269 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'node_ip_address': '9.211.53.245',\n",
       " 'raylet_ip_address': '9.211.53.245',\n",
       " 'redis_address': '9.211.53.245:29680',\n",
       " 'object_store_address': '/tmp/ray/session_2021-06-19_21-02-53_442708_86627/sockets/plasma_store',\n",
       " 'raylet_socket_name': '/tmp/ray/session_2021-06-19_21-02-53_442708_86627/sockets/raylet',\n",
       " 'webui_url': '127.0.0.1:8266',\n",
       " 'session_dir': '/tmp/ray/session_2021-06-19_21-02-53_442708_86627',\n",
       " 'metrics_export_port': 61863,\n",
       " 'node_id': 'eff8f3d5558252aa2d18c57b62891d1a169a7404979a63b6c9882a14'}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Start Ray and init\n",
    "\n",
    "import ray\n",
    "ray.init(object_store_memory=16 * 1024 * 1024 * 1024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import KFold\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "\n",
    "# Splitter handed to the CodeFlare cross validation below: stratified, 10 folds.\n",
    "kf = StratifiedKFold(n_splits=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import FeatureUnion\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.kernel_approximation import Nystroem\n",
    "from sklearn.feature_selection import SelectKBest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scikit-learn counterpart of the LALE `PCA() & Nystroem() & SelectKBest(k=3)` stage:\n",
    "# a FeatureUnion over PCA, Nystroem and SelectKBest(k=3).\n",
    "feature_union = FeatureUnion(\n",
    "    transformer_list=[\n",
    "        ('PCA', PCA()),\n",
    "        ('Nystroem', Nystroem()),\n",
    "        ('SelectKBest', SelectKBest(k=3)),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final estimator for the CodeFlare pipeline, using the same n_estimators=200 as the LALE pipeline.\n",
    "random_forest = RandomForestClassifier(\n",
    "    n_estimators=200,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import codeflare.pipelines.Datamodel as dm\n",
    "\n",
    "# Wrap the feature union and the random forest in CodeFlare estimator nodes.\n",
    "node_fu = dm.EstimatorNode('feature_union', feature_union)\n",
    "node_rf = dm.EstimatorNode('randomforest', random_forest)\n",
    "\n",
    "# Create the CodeFlare pipeline (this rebinds `pipeline`, which held the LALE pipeline until now)\n",
    "# and connect the two nodes with an edge: feature union -> random forest.\n",
    "pipeline = dm.Pipeline()\n",
    "pipeline.add_edge(node_fu, node_rf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import codeflare.pipelines.Runtime as rt\n",
    "from codeflare.pipelines.Runtime import ExecutionType"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the training data and register it as the input of the feature-union node.\n",
    "xy = dm.Xy(X_train, y_train)\n",
    "\n",
    "pipeline_input = dm.PipelineInput()\n",
    "pipeline_input.add_xy_arg(node_fu, xy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3.06 s, sys: 2.11 s, total: 5.17 s\n",
      "Wall time: 1min 15s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "scores = rt.cross_validate(kf, pipeline, pipeline_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.8161838161838162,\n",
       " 0.8118548118548119,\n",
       " 0.8145188145188145,\n",
       " 0.8198468198468198,\n",
       " 0.8175158175158175,\n",
       " 0.8111888111888111,\n",
       " 0.8158508158508159,\n",
       " 0.8171828171828172,\n",
       " 0.8037974683544303,\n",
       " 0.8087941372418388]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stop the Ray runtime that was started earlier in this notebook.\n",
    "ray.shutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
